# @package _global_
# Override the config-group selections made by the primary config's defaults list.
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  # NOTE(review): `null` deselects the model group here; presumably a more specific
  # experiment config that builds on this one selects the actual model — confirm.
  - override /model: null
  - override /datamodule: thepile
  - override /optimizer: adamw-apex  # slight speedup (1-2%) over PyTorch AdamW
  - override /scheduler: cosine-warmup-timm
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# All parameters below are merged on top of the default configurations selected above,
# so you only need to specify the values you want to override.

# Training task class, referenced through Hydra's `_target_` convention.
task:
  _target_: src.tasks.seq.SequenceLMModel

# Global random seed for the run.
seed: 1111

trainer:
  # Hardware layout: a single node with 8 GPUs.
  accelerator: gpu
  num_nodes: 1
  devices: 8
  # Micro-batches per optimizer step, derived via the project's `div_up` resolver
  # (presumably ceiling division) from the target global batch size divided by
  # devices * per-GPU batch size * num_nodes.
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 800000
  # Scaled by the accumulation factor so that validation runs every 2000 optimizer
  # steps (val_check_interval appears to count training micro-batches).
  val_check_interval: ${eval:2000 * ${trainer.accumulate_grad_batches}}
  check_val_every_n_epoch: null  # Validation is scheduled by step count, not by epoch boundaries
  strategy: null
  precision: bf16
  gradient_clip_val: 1.0

datamodule:
  batch_size: 16  # Per GPU
  batch_size_eval: ${.batch_size}  # Fused dense only supports batch sizes of at most 64k
  max_length: 2048
  fault_tolerant: true
  # Whether the run is distributed (presumably the datamodule uses this to pick a
  # distributed sampler). Count GPUs across all nodes, so a multi-node run with a
  # single GPU per node is also treated as distributed.
  ddp: ${eval:"${trainer.devices} * ${trainer.num_nodes} > 1"}

train:
  # Total memory of GPU 0 as reported by nvidia-smi (MiB) divided by 1000 and rounded,
  # i.e. a rough GB figure (e.g. 40960 MiB -> 41). Runs nvidia-smi through a shell on
  # whatever host resolves this value, so it fails on machines without nvidia-smi.
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  # Target global batch size per optimizer step; trainer.accumulate_grad_batches is derived from it.
  global_batch_size: 256
  optimizer:
    lr: 6.0e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  scheduler:
    t_in_epochs: false
    # NOTE(review): the cosine cycle spans 600k steps while trainer.max_steps is 800k,
    # so the LR presumably sits at lr_min for the final 200k steps — confirm intent.
    t_initial: 600000
    warmup_lr_init: 1.0e-6
    warmup_t: ${eval:0.01 * ${trainer.max_steps}}  # Warm up over the first 1% of max_steps
    lr_min: ${eval:0.1 * ${train.optimizer.lr}}  # Decay floor: 10% of the peak LR
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: true  # to save memory

eval:
  log_on_step: true  # One training epoch takes too long, so log metrics at every training step

callbacks:
  # Best-checkpoint tracking: keep the 3 checkpoints with the lowest val/loss plus the
  # latest one (presumably refining the callback defined in the `default` callbacks group).
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: true
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}  # Empty subdirectory when `name` is unset
    filename: step_{step}
    auto_insert_metric_name: false
  # Progress snapshots: keep every checkpoint taken at 50k-step intervals.
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    # fault_tolerant: True  # The .pl_auto_save.ckpt doesn't get saved by all workers
    every_n_train_steps: 50000
    save_last: false
    save_top_k: -1  # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}  # Same directory as model_checkpoint
    filename: progress_step_{step}
    auto_insert_metric_name: false
  early_stopping: null  # Disable early stopping

